import torch
from typing import Any
import numpy as np
import mujoco
from concurrent.futures import ThreadPoolExecutor, ProcessPoolExecutor
import functools
import dataclasses
from humenv import make_humenv
from humenv.rewards import RewardFunction


def get_next(field: str, data: Any):
    """Return the next-step value of ``field`` stored in ``data``.

    Two layouts are recognised, checked in this order:

    * nested:  ``data["next"][field]``
    * flat:    ``data[f"next_{field}"]``

    Raises:
        ValueError: if neither layout contains the requested field.
    """
    nested_hit = "next" in data and field in data["next"]
    if nested_hit:
        return data["next"][field]

    flat_key = f"next_{field}"
    if flat_key in data:
        return data[flat_key]

    raise ValueError(f"No next of {field} found in data.")


@dataclasses.dataclass(kw_only=True)
class BaseHumEnvBenchWrapper:
    model: Any
    numpy_output: bool = True
    _dtype: torch.dtype = dataclasses.field(default_factory=lambda: torch.float32)

    def act(
        self,
        obs: torch.Tensor | np.ndarray,
        z: torch.Tensor | np.ndarray,
        mean: bool = True,
    ) -> torch.Tensor:
        obs = to_torch(obs, device=self.device, dtype=self._dtype)
        z = to_torch(z, device=self.device, dtype=self._dtype)
        if self.numpy_output:
            return self.unwrapped_model.act(obs, z, mean).cpu().detach().numpy()
        return self.unwrapped_model.act(obs, z, mean)

    @property
    def device(self) -> Any:
        # this returns the base torch.nn.module
        return self.unwrapped_model.cfg.device

    @property
    def unwrapped_model(self):
        # this is used to call the base instance of model
        if hasattr(self.model, "unwrapped_model"):
            return self.model.unwrapped_model
        else:
            return self.model

    def __getattr__(self, name):
        # Delegate to the wrapped instance
        return getattr(self.model, name)


@dataclasses.dataclass(kw_only=True)
class RewardWrapper(BaseHumEnvBenchWrapper):
    """Wrapper that infers a context for a reward task.

    Transitions taken from ``inference_dataset`` are relabelled with the
    task's reward and passed to the model method named by
    ``inference_function``.
    """

    # Source of transitions; indexed by key and expected to offer
    # ``get_full_buffer()`` (used when no subsampling is needed).
    inference_dataset: Any
    # Number of transitions drawn per call to ``reward_inference``.
    num_samples_per_inference: int
    # Name of the method looked up on ``model`` to compute the context.
    inference_function: str
    # Executor settings forwarded to ``relabel``.
    max_workers: int
    process_executor: bool = False
    process_context: str = "spawn"

    def _sample_inference_data(self) -> Any:
        """Return the transitions to relabel.

        When the dataset holds more rows than ``num_samples_per_inference``,
        rows are drawn uniformly with replacement; otherwise the full buffer
        is returned.
        """
        dataset = self.inference_dataset
        num_available = len(dataset["observation"])
        if self.num_samples_per_inference >= num_available:
            return dataset.get_full_buffer()

        # The upper bound of randint is exclusive, so ``num_available`` (not
        # ``num_available - 1``) is required for the last row to be reachable.
        idx = np.random.randint(0, num_available, size=self.num_samples_per_inference)

        def gather(source, key):
            return torch.Tensor(source[key][idx]).to(self.device)

        return {
            "observation": gather(dataset, "observation"),
            "qpos": gather(dataset, "qpos"),
            "qvel": gather(dataset, "qvel"),
            "action": gather(dataset, "action"),
            "next": {
                "observation": gather(dataset["next"], "observation"),
                "qpos": gather(dataset["next"], "qpos"),
                "qvel": gather(dataset["next"], "qvel"),
            },
        }

    def reward_inference(self, task: str) -> torch.Tensor:
        """Infer the context for ``task``.

        Creates the environment for ``task``, relabels the selected
        transitions with ``env.unwrapped.task`` and calls the configured
        inference function with the rewards plus either ``B_vect`` (when the
        data contains ``"B"``) or ``next_obs``.

        Returns:
            The inference function's output reshaped to ``(1, -1)``.

        Raises:
            AttributeError: if ``model`` has no attribute named
                ``inference_function``.
        """
        env, _ = make_humenv(task=task)
        try:
            data = self._sample_inference_data()
            qpos = get_next("qpos", data)
            qvel = get_next("qvel", data)
            action = data["action"]
            if isinstance(qpos, torch.Tensor):
                # relabel works on NumPy arrays
                qpos = qpos.cpu().detach().numpy()
                qvel = qvel.cpu().detach().numpy()
                action = action.cpu().detach().numpy()
            rewards = relabel(
                env,
                qpos,
                qvel,
                action,
                env.unwrapped.task,
                max_workers=self.max_workers,
                process_executor=self.process_executor,
                process_context=self.process_context,
            )
        finally:
            # Release the environment even when sampling or relabelling fails.
            env.close()

        td = {
            "reward": torch.tensor(rewards, dtype=torch.float32, device=self.device),
        }
        if "B" in data:
            td["B_vect"] = data["B"]
        else:
            td["next_obs"] = get_next("observation", data)

        inference_fn = getattr(self.model, self.inference_function, None)
        if inference_fn is None:
            raise AttributeError(
                f"Model has no inference function named {self.inference_function!r}."
            )
        return inference_fn(**td).reshape(1, -1)


@dataclasses.dataclass(kw_only=True)
class GoalWrapper(BaseHumEnvBenchWrapper):
    """Wrapper exposing the model's goal inference."""

    def goal_inference(self, goal_pose: torch.Tensor) -> torch.Tensor:
        """Return the model's goal context for ``goal_pose``, reshaped to ``(1, -1)``.

        ``goal_pose`` is converted with ``to_torch`` using this wrapper's
        device and dtype and passed to the model as ``next_obs``.
        """
        goal = to_torch(goal_pose, device=self.device, dtype=self._dtype)
        inferred = self.unwrapped_model.goal_inference(next_obs=goal)
        return inferred.reshape(1, -1)


@dataclasses.dataclass(kw_only=True)
class TrackingWrapper(BaseHumEnvBenchWrapper):
    """Wrapper exposing the model's tracking inference."""

    def tracking_inference(self, next_obs: torch.Tensor | np.ndarray) -> torch.Tensor:
        """Return the model's tracking context for ``next_obs``.

        The input is converted with ``to_torch`` using this wrapper's device
        and dtype; the model output is returned without reshaping.
        """
        converted = to_torch(next_obs, device=self.device, dtype=self._dtype)
        return self.unwrapped_model.tracking_inference(next_obs=converted)


def to_torch(x: np.ndarray | torch.Tensor, device: torch.device | str, dtype: torch.dtype):
    if len(x.shape) == 1:
        # adding batch dimension
        x = x[None, ...]
    if not isinstance(x, torch.Tensor):
        x = torch.tensor(x, device=device, dtype=dtype)
    else:
        x = x.to(dtype)
    return x


def _relabel_worker(
    x,
    model: mujoco.MjModel,
    reward_fn: RewardFunction,
):
    """Compute rewards for one chunk of transitions.

    Args:
        x: ``(qpos, qvel, action)`` triple; the three arrays must have the
            same length along their first axis and ``qpos`` must have at
            least two dimensions.
        model: passed as first argument to ``reward_fn``.
        reward_fn: called as ``reward_fn(model, qpos[i], qvel[i], action[i])``
            for every row ``i``.

    Returns:
        Array of shape ``(len(qpos), 1)`` holding one reward per row.
    """
    qpos, qvel, action = x

    # Sanity checks on the chunk layout.
    assert len(qpos.shape) > 1
    assert qvel.shape[0] == qpos.shape[0]
    assert qvel.shape[0] == action.shape[0]

    num_rows = qpos.shape[0]
    rewards = np.zeros(shape=(num_rows, 1))
    for row in range(num_rows):
        reward = reward_fn(model, qpos[row], qvel[row], action[row])
        rewards[row] = reward
    return rewards


def relabel(
    env: Any,
    qpos: np.ndarray,
    qvel: np.ndarray,
    action: np.ndarray,
    reward_fn: RewardFunction,
    max_workers: int = 5,
    process_executor: bool = False,
    process_context: str = "spawn",
):
    """Compute ``reward_fn`` for every transition, optionally in parallel.

    The inputs are split along their first axis into at most ``max_workers``
    contiguous chunks, each handled by ``_relabel_worker`` with
    ``env.unwrapped.model``.

    Args:
        env: environment providing ``env.unwrapped.model``.
        qpos, qvel, action: arrays sliced along their first axis.
        reward_fn: reward function forwarded to the worker.
        max_workers: maximum number of chunks / pool workers; must be >= 1.
        process_executor: use a process pool instead of a thread pool.
        process_context: multiprocessing start method for the process pool.

    Returns:
        Rewards concatenated in input order; an empty ``(0, 1)`` array when
        there are no transitions.

    Raises:
        ValueError: if ``max_workers`` is smaller than 1.
    """
    if max_workers < 1:
        raise ValueError(f"max_workers must be >= 1, got {max_workers}.")

    num_samples = qpos.shape[0]
    if num_samples == 0:
        # A chunk size of 0 would be an invalid step for range() below.
        return np.zeros((0, 1))

    chunk_size = int(np.ceil(num_samples / max_workers))
    chunks = [
        (qpos[i : i + chunk_size], qvel[i : i + chunk_size], action[i : i + chunk_size])
        for i in range(0, num_samples, chunk_size)
    ]
    worker = functools.partial(_relabel_worker, model=env.unwrapped.model, reward_fn=reward_fn)

    if len(chunks) == 1:
        # Covers max_workers == 1; a pool would add nothing for a single chunk.
        results = [worker(chunks[0])]
    elif process_executor:
        import multiprocessing

        with ProcessPoolExecutor(
            max_workers=max_workers,
            mp_context=multiprocessing.get_context(process_context),
        ) as exe:
            # Materialise inside the context so results (and worker errors)
            # are collected while the pool is still alive.
            results = list(exe.map(worker, chunks))
    else:
        with ThreadPoolExecutor(max_workers=max_workers) as exe:
            results = list(exe.map(worker, chunks))

    return np.concatenate(results)
